{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run upon export from spreadsheet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "f = '/mnt/tess/labels/triage-v14.csv'\n",
    "\n",
    "t = pd.read_csv(f, header=0, low_memory=False).set_index('Astro ID')\n",
    "\n",
    "disps = ['E', 'J', 'N', 'S', 'B']\n",
    "users = ['av', 'md', 'ch', 'as', 'mk', 'et', 'dm', 'td']\n",
    "for d in disps:\n",
    "  t[f'disp_{d}'] = 0\n",
    "\n",
    "def set_labels(row):\n",
    "  a = ~row.isna()\n",
    "  if a['Final'] and row['Final'] in disps:\n",
    "    row[f'disp_{row[\"Final\"]}'] = 1\n",
    "  else:\n",
    "    for user in users:\n",
    "      if a[user] and row[user] in disps:\n",
    "        row[f'disp_{row[user]}'] += 1\n",
    "  return row\n",
    "t = t.apply(set_labels, axis=1)\n",
    "\n",
    "t_train = t[t['Split'] == 'train']\n",
    "t_val = t[t['Split'] == 'val']\n",
    "t_test = t[t['Split'] == 'test']\n",
    "\n",
    "assert not any((t_train['disp_E'] + t_train['disp_J']+ t_train['disp_N'] + t_train['disp_S'] + t_train['disp_B']) == 0)\n",
    "assert not any((t_val['disp_E'] + t_val['disp_J']+ t_val['disp_N'] + t_val['disp_S'] + t_val['disp_B']) == 0)\n",
    "assert not any((t_test['disp_E'] + t_test['disp_J']+ t_test['disp_N'] + t_test['disp_S'] + t_test['disp_B']) == 0)\n",
    "\n",
    "print('Splits')\n",
    "print('  train:', len(t_train))\n",
    "print('  val:', len(t_val))\n",
    "print('  test:', len(t_test))\n",
    "\n",
    "t_train.to_csv('/mnt/tess/astronet/tces-v14-train.csv')\n",
    "t_val.to_csv('/mnt/tess/astronet/tces-v14-val.csv')\n",
    "t_test.to_csv('/mnt/tess/astronet/tces-v14-test.csv')\n",
    "t.to_csv('/mnt/tess/astronet/tces-v14-all.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.set_option('display.max_columns', None)\n",
    "t_train.sample(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t_val.sample(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t_test.sample(5)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
